import fmt
import crypto.blowfish
import time
import libc

// bcrypt constant definitions

// Accepted cost range. The cost is the log2 of the number of key-expansion
// rounds performed in expensive_blowfish_setup (rounds = 1 << cost).
// new_from_password substitutes DEFAULT_COST when given a cost below MIN_COST.
const MIN_COST = 4
const MAX_COST = 31
const DEFAULT_COST = 10

// Version bytes written into newly generated hashes ("$2a$...").
const MAJOR_VERSION = 50 // '2'
const MINOR_VERSION = 97 // 'a'
// Length in bytes of the raw (unencoded) random salt.
const MAX_SALT_SIZE = 16
// 24 bytes are encrypted, but only the first 23 are encoded into the digest.
const MAX_CRYPTED_HASH_SIZE = 23
// Encoded lengths: 16 salt bytes -> 22 characters, 23 digest bytes -> 31 characters.
const ENCODED_SALT_SIZE = 22
const ENCODED_HASH_SIZE = 31
// Shortest accepted hash string: "$2$NN$" (6, no minor version byte)
// + salt (22) + digest (31) = 59.
const MIN_HASH_SIZE = 59

// ASCII codes used while parsing and formatting hash strings.
const DOLLAR_CHAR = 36 // $
const ZERO_CHAR = 48 // '0'
const NINE_CHAR = 57 // '9'

// bcrypt-specific base64 alphabet. The 64 symbols are the same as standard
// base64, but the order differs ('.' and '/' come first), and no '=' padding is used.
const ALPHABET = "./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

// Magic cipher data: the ASCII bytes of the 24-character string
// "OrpheanBeholderScryDoubt". bcrypt encrypts this text with the derived key
// to produce the digest.
[u8] MAGIC_CIPHER_DATA = [
    0x4f, 0x72, 0x70, 0x68,
    0x65, 0x61, 0x6e, 0x42,
    0x65, 0x68, 0x6f, 0x6c,
    0x64, 0x65, 0x72, 0x53,
    0x63, 0x72, 0x79, 0x44,
    0x6f, 0x75, 0x62, 0x74,
]

// Error messages passed to errorf when throwing.
const ERROR_PASSWORD_TOO_LONG = "password too long"
const ERROR_INVALID_COST = "invalid cost parameter"
const ERROR_HASH_TOO_SHORT = "hash too short"
const ERROR_INVALID_HASH_PREFIX = "invalid hash prefix"
const ERROR_HASH_VERSION_TOO_NEW = "hash version too new"
const ERROR_NOT_VERIFY = "verify failed"

// Parsed form of a bcrypt hash string: "$<major>[<minor>]$<cost>$<salt><hash>".
// salt and hash hold the encoded (ALPHABET base64) characters, not raw bytes.
type bcrypt_hash_t = struct {
    // Encoded digest. When parsed from a string, this is every byte after the salt.
    [u8] hash
    // Encoded salt (ENCODED_SALT_SIZE characters when parsed).
    [u8] salt
    // log2 of the number of key-expansion rounds.
    int cost
    // ASCII major version byte, e.g. '2'.
    u8 major
    // ASCII minor version byte, e.g. 'a'; 0 means "no minor byte" when formatting.
    u8 minor
}

// Encode bytes with the bcrypt base64 alphabet, without '=' padding.
// Every 3 input bytes become 4 characters. A trailing group of 1 or 2 bytes
// becomes 2 or 3 characters respectively.
fn base64_encode([u8] src):[u8] {
    [u8] out = []
    int total = src.len()
    int pos = 0

    for pos < total {
        // Number of real input bytes in this group (missing ones read as zero).
        int avail = total - pos

        u32 group = (src[pos] as u32) << 16
        if avail >= 2 {
            group |= (src[pos + 1] as u32) << 8
        }
        if avail >= 3 {
            group |= src[pos + 2] as u32
        }

        // The first two sextets are always emitted; the rest only when backed by data.
        out.push(ALPHABET[(group >> 18) & 0x3F])
        out.push(ALPHABET[(group >> 12) & 0x3F])
        if avail >= 2 {
            out.push(ALPHABET[(group >> 6) & 0x3F])
        }
        if avail >= 3 {
            out.push(ALPHABET[group & 0x3F])
        }

        pos += 3
    }

    return out
}

// Decode bcrypt-flavoured base64 (ALPHABET, no '=' padding) back into bytes.
//
// Each character contributes 6 bits. A byte is emitted whenever 8 bits are
// pending, and the leftover low bits of a final partial group are discarded.
// So 22 characters decode to 16 bytes and 31 characters decode to 23 bytes.
//
// Throws ERROR_INVALID_HASH_PREFIX in two cases:
//   - src contains a byte outside ALPHABET, including '='. Padding is never
//     part of a bcrypt string, and treating it as zero bits would silently
//     accept malformed salts.
//   - src.len() % 4 == 1. A dangling single character carries only 6 bits and
//     cannot encode a whole byte.
fn base64_decode([u8] src):[u8]! {
    int src_len = src.len()
    if src_len % 4 == 1 {
        throw errorf(ERROR_INVALID_HASH_PREFIX)
    }

    [u8] result = []
    // acc holds pending bits right-aligned; bits counts how many are pending.
    u32 acc = 0
    int bits = 0

    for int i = 0; i < src_len; i += 1 {
        // Map the character to its 6-bit value via the alphabet.
        int value = -1
        for int k = 0; k < ALPHABET.len(); k += 1 {
            if ALPHABET[k] == src[i] {
                value = k
                break
            }
        }
        if value == -1 {
            throw errorf(ERROR_INVALID_HASH_PREFIX)
        }

        // Never more than 12 bits are live here (<= 6 pending + 6 new).
        // Masking to 16 bits therefore only drops already-consumed bits and
        // keeps acc from overflowing.
        acc = ((acc << 6) | (value as u32)) & 0xFFFF
        bits += 6
        if bits >= 8 {
            bits -= 8
            result.push(((acc >> (bits as u32)) & 0xFF) as u8)
        }
    }

    return result
}

// Ensure cost lies within [MIN_COST, MAX_COST]; throw ERROR_INVALID_COST otherwise.
fn check_cost(int cost):void! {
    bool in_range = cost >= MIN_COST && cost <= MAX_COST
    if !in_range {
        throw errorf(ERROR_INVALID_COST)
    }
}

// Whether the libc PRNG has already been seeded by generate_salt.
// Seeding must happen only once per process. Re-seeding with the
// second-resolution clock on every call made every salt generated within
// the same second identical, which defeats the purpose of a salt.
bool salt_rng_seeded = false

// Generate MAX_SALT_SIZE raw random salt bytes.
//
// NOTE(review): libc rand() is not a cryptographically secure generator, and a
// time-based seed is guessable; processes started in the same second still
// share a sequence. bcrypt needs salts to be unique rather than secret, but
// this should be replaced with an OS CSPRNG (e.g. the system random device)
// once one is reachable from this module.
fn generate_salt():[u8] {
    if !salt_rng_seeded {
        libc.srand(time.unix() as u32)
        salt_rng_seeded = true
    }

    [u8] salt = []
    for int i = 0; i < MAX_SALT_SIZE; i += 1 {
        salt.push((libc.rand() % 256) as u8)
    }
    return salt
}

// Derive the bcrypt Blowfish key schedule.
//
// Steps:
//   1. Decode the encoded salt.
//   2. Build a salted cipher from the password plus a trailing 0 byte.
//   3. Perform 2^cost rounds, each expanding the key with the password and
//      then with the salt.
//
// key  - raw password bytes (copied; a trailing 0 byte is appended to the copy)
// cost - log2 of the number of expansion rounds, must be in MIN_COST..MAX_COST
// salt - salt encoded with the bcrypt base64 alphabet
// Throws ERROR_INVALID_COST for an out-of-range cost, propagates salt
// decoding errors, and throws if the salted cipher cannot be created.
fn expensive_blowfish_setup([u8] key, u32 cost, [u8] salt):blowfish.cipher_t! {
    // Guard the shift below: with an unchecked cost, 1 << cost could overflow.
    check_cost(cost as int)

    var cipher = blowfish.cipher_t{
        p: vec_new<u32>(0, 18),
        s0: vec_new<u32>(0, 256),
        s1: vec_new<u32>(0, 256),
        s2: vec_new<u32>(0, 256),
        s3: vec_new<u32>(0, 256),
    }

    // Decode salt
    [u8] csalt = base64_decode(salt)

    // Add a trailing NUL byte for compatibility with the C implementation.
    // It goes on a copy so the caller's password buffer is left untouched.
    [u8] ckey = []
    for i, v in key {
        ckey.push(v)
    }
    ckey.push(0)

    // Create salted cipher
    int result = blowfish.new_salted_cipher(ckey, csalt, &cipher)
    if result != 0 {
        throw errorf("failed to create salted cipher")
    }

    // Explicit parentheses and types: `1 << cost as u64` depended on the
    // relative precedence of `<<` and `as`, and mixed an int literal with u64.
    u64 rounds = (1 as u64) << (cost as u64)
    for u64 i = 0; i < rounds; i += 1 {
        blowfish.expand_key(ckey, &cipher)
        blowfish.expand_key(csalt, &cipher)
    }

    return cipher
}

// Compute the encoded bcrypt digest for password under the given cost and
// encoded salt.
//
// The 24-byte magic text is split into three 8-byte blocks. Each block is
// encrypted 64 times with the expensively derived cipher. The first
// MAX_CRYPTED_HASH_SIZE bytes of the result are returned, base64-encoded.
fn bcrypt([u8] password, int cost, [u8] salt):[u8]! {
    // Working copy of the magic plaintext; encrypted in place below.
    [u8] state = []
    int magic_len = MAGIC_CIPHER_DATA.len()
    for int n = 0; n < magic_len; n += 1 {
        state.push(MAGIC_CIPHER_DATA[n])
    }

    blowfish.cipher_t cipher = expensive_blowfish_setup(password, cost as u32, salt)

    int offset = 0
    for offset < 24 {
        int rep = 0
        for rep < 64 {
            [u8] src_block = []
            for int n = 0; n < 8; n += 1 {
                src_block.push(state[offset + n])
            }
            [u8] dst_block = [0, 0, 0, 0, 0, 0, 0, 0]

            blowfish.encrypt(cipher, dst_block, src_block)

            for int n = 0; n < 8; n += 1 {
                state[offset + n] = dst_block[n]
            }
            rep += 1
        }
        offset += 8
    }

    // Keep only the first 23 of the 24 encrypted bytes.
    [u8] truncated = []
    for int n = 0; n < MAX_CRYPTED_HASH_SIZE; n += 1 {
        truncated.push(state[n])
    }

    [u8] encoded = base64_encode(truncated)
    return encoded
}

// Build a fresh bcrypt_hash_t for password using a newly generated salt.
//
// Throws ERROR_PASSWORD_TOO_LONG for passwords over 72 bytes. A cost below
// MIN_COST is replaced by DEFAULT_COST; any other out-of-range cost throws
// ERROR_INVALID_COST.
fn new_from_password([u8] password, int cost):bcrypt_hash_t! {
    if password.len() > 72 {
        throw errorf(ERROR_PASSWORD_TOO_LONG)
    }

    int effective_cost = cost
    if effective_cost < MIN_COST {
        effective_cost = DEFAULT_COST
    }
    check_cost(effective_cost)

    [u8] raw_salt = generate_salt()
    [u8] encoded_salt = base64_encode(raw_salt)
    [u8] digest = bcrypt(password, effective_cost, encoded_salt)

    var fresh = bcrypt_hash_t{
        hash: digest,
        salt: encoded_salt,
        cost: effective_cost,
        major: MAJOR_VERSION as u8,
        minor: MINOR_VERSION as u8,
    }
    return fresh
}

// Parse a "$<major>[<minor>]$<NN>$<salt><hash>" string into a bcrypt_hash_t.
//
// Checks are applied in order and throw:
//   ERROR_HASH_TOO_SHORT       - input shorter than MIN_HASH_SIZE (or truncated fields)
//   ERROR_INVALID_HASH_PREFIX  - missing '$' separators or non-digit cost
//   ERROR_HASH_VERSION_TOO_NEW - major version byte above MAJOR_VERSION
//   ERROR_INVALID_COST         - cost outside MIN_COST..MAX_COST
// The salt is the next ENCODED_SALT_SIZE bytes, and the hash is everything after it.
fn new_from_hash([u8] hashed_secret):bcrypt_hash_t! {
    int total = hashed_secret.len()
    if total < MIN_HASH_SIZE {
        throw errorf(ERROR_HASH_TOO_SHORT)
    }

    // Version section: "$<major>[<minor>]$"
    if hashed_secret[0] != DOLLAR_CHAR {
        throw errorf(ERROR_INVALID_HASH_PREFIX)
    }
    u8 major = hashed_secret[1]
    if major > (MAJOR_VERSION as u8) {
        throw errorf(ERROR_HASH_VERSION_TOO_NEW)
    }

    u8 minor = 0
    int cursor = 2
    if hashed_secret[cursor] != DOLLAR_CHAR {
        minor = hashed_secret[cursor]
        cursor += 1
    }
    if hashed_secret[cursor] != DOLLAR_CHAR {
        throw errorf(ERROR_INVALID_HASH_PREFIX)
    }
    cursor += 1

    // Cost section: two decimal digits followed by '$'
    if cursor + 2 >= total {
        throw errorf(ERROR_HASH_TOO_SHORT)
    }
    int parsed_cost = 0
    for int d = cursor; d < cursor + 2; d += 1 {
        u8 ch = hashed_secret[d]
        bool is_digit = ch >= ZERO_CHAR && ch <= NINE_CHAR
        if !is_digit {
            throw errorf(ERROR_INVALID_HASH_PREFIX)
        }
        parsed_cost = parsed_cost * 10 + (ch - ZERO_CHAR) as int
    }
    cursor += 2

    check_cost(parsed_cost)

    if hashed_secret[cursor] != DOLLAR_CHAR {
        throw errorf(ERROR_INVALID_HASH_PREFIX)
    }
    cursor += 1

    // Fixed-width encoded salt, then the remainder is the encoded hash.
    if cursor + ENCODED_SALT_SIZE > total {
        throw errorf(ERROR_HASH_TOO_SHORT)
    }
    int salt_end = cursor + ENCODED_SALT_SIZE
    [u8] salt_chars = []
    [u8] hash_chars = []
    for int k = cursor; k < total; k += 1 {
        if k < salt_end {
            salt_chars.push(hashed_secret[k])
        } else {
            hash_chars.push(hashed_secret[k])
        }
    }

    var parsed = bcrypt_hash_t{
        hash: hash_chars,
        salt: salt_chars,
        cost: parsed_cost,
        major: major,
        minor: minor,
    }
    return parsed
}

// Format hash_info as "$<major>[<minor>]$<NN>$<salt><hash>".
// The minor byte is omitted when it is 0, and the cost is written as two digits.
fn hash_to_string(bcrypt_hash_t hash_info):[u8] {
    [u8] out = []

    // Version section
    out.push(DOLLAR_CHAR)
    out.push(hash_info.major)
    if hash_info.minor != 0 {
        out.push(hash_info.minor)
    }
    out.push(DOLLAR_CHAR)

    // Cost section
    int tens = hash_info.cost / 10
    int ones = hash_info.cost % 10
    out.push((ZERO_CHAR + tens) as u8)
    out.push((ZERO_CHAR + ones) as u8)
    out.push(DOLLAR_CHAR)

    // Salt and digest, appended verbatim
    int salt_len = hash_info.salt.len()
    for int k = 0; k < salt_len; k += 1 {
        out.push(hash_info.salt[k])
    }
    int hash_len = hash_info.hash.len()
    for int k = 0; k < hash_len; k += 1 {
        out.push(hash_info.hash[k])
    }

    return out
}

// Report whether a and b hold identical bytes.
// For equal-length inputs, every byte is inspected regardless of where the
// first difference occurs. A length mismatch returns false immediately.
fn constant_time_compare([u8] a, [u8] b):bool {
    int n = a.len()
    if n != b.len() {
        return false
    }

    u8 diff = 0
    int idx = 0
    for idx < n {
        diff = diff | (a[idx] ^ b[idx])
        idx += 1
    }
    return diff == 0
}

// Public API functions

// Hash password at the given cost and return the full "$2a$NN$..." string.
// Propagates any error from new_from_password.
fn hash([u8] password, int cost):[u8]! {
    var generated = new_from_password(password, cost)
    [u8] formatted = hash_to_string(generated)
    return formatted
}

// Check password against a stored bcrypt hash string.
// The digest is recomputed with the stored cost and salt. The full hash
// strings are then compared with constant_time_compare. Throws
// ERROR_NOT_VERIFY on mismatch, and propagates parse and hashing errors.
fn verify([u8] hashed_password, [u8] password):void! {
    var stored = new_from_hash(hashed_password)

    [u8] recomputed = bcrypt(password, stored.cost, stored.salt)
    var candidate = bcrypt_hash_t{
        hash: recomputed,
        salt: stored.salt,
        cost: stored.cost,
        major: stored.major,
        minor: stored.minor,
    }

    [u8] expected = hash_to_string(stored)
    [u8] actual = hash_to_string(candidate)
    bool matches = constant_time_compare(expected, actual)
    if !matches {
        throw errorf(ERROR_NOT_VERIFY)
    }
}

// Return the cost encoded in a bcrypt hash string; propagates parse errors.
fn cost([u8] hashed_password):int! {
    var parsed = new_from_hash(hashed_password)
    int extracted = parsed.cost
    return extracted
}

// Smoke test: hash a password, verify it, reject a wrong password, and
// read the cost back out of the generated hash.
fn main():void! {
    println("bcrypt test started...")

    string password = "mypassword"
    println("Password: ", password)
    println("Cost: ", DEFAULT_COST)

    [u8] hashed = hash(password as [u8], DEFAULT_COST)
    println("Generated hash: ", hashed as string)

    // The original password must verify.
    try {
        verify(hashed, password as [u8])
        println("✓ Password verification successful")
    } catch err {
        println("✗ Password verification failed")
    }

    // A different password must be rejected.
    string wrong_password = "wrongpassword"
    try {
        verify(hashed, wrong_password as [u8])
        println("✗ Wrong password verification anomaly")
    } catch err {
        println("✓ Wrong password correctly rejected")
    }

    // The cost embedded in the hash must round-trip.
    try {
        int extracted_cost = cost(hashed)
        if extracted_cost != DEFAULT_COST {
            println("✗ Cost extraction failed")
        } else {
            println("✓ Cost extraction correct: ", extracted_cost)
        }
    } catch err {
        println("✗ Cost extraction failed")
    }

    println("bcrypt test completed")
}